pub mod batch;
pub mod feedback;
mod gnn;
mod replay;
pub use batch::{
BatchInput, BatchJob, BatchScheduler, EntryMetadata, Insight, InsightStore, JobRun, JobStatus,
JobType, KnowledgeClass, RelationshipType, Trend,
};
pub use feedback::{
FeedbackCollector, FeedbackConfig, FeedbackProcessor, FeedbackSignal, ProcessedFeedback,
QueryId, SessionId, SignalType,
};
pub use gnn::GnnLayer;
pub use replay::ReplayBuffer;
use crate::core::SearchResult;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use uuid::Uuid;
/// Online re-ranking state: per-dimension relevance weights, a GNN layer that
/// transforms queries, a replay buffer of past query results, and per-entry
/// score boosts learned from that buffer.
pub struct LearningEngine {
/// Embedding width; `relevance_weights` and `fisher_diagonal` are allocated
/// with this length in `new`, and per-dimension loops are bounded by it.
dimensions: usize,
/// Step size applied in `record_feedback` and forwarded to the GNN layer's
/// `update` in `learn_from_replay`.
learning_rate: f32,
/// Per-dimension weights used by `weighted_distance`; start at 1.0 and are
/// clamped to `[0.1, 10.0]` whenever `record_feedback` adjusts them.
relevance_weights: Vec<f32>,
/// Layer used by `rerank` to transform the query from its top candidates and
/// updated from sampled experiences in `learn_from_replay`.
gnn_layer: GnnLayer,
/// Experiences appended by `record_query` and sampled by `learn_from_replay`.
replay_buffer: ReplayBuffer<Experience>,
// Allocated in `new` but not otherwise read or written in this file, hence
// the `dead_code` allowance.
#[allow(dead_code)]
query_patterns: VecDeque<QueryPattern>,
/// Per-entry boost; `rerank` divides an entry's distance by this value
/// (1.0 when the entry has no recorded boost).
entry_scores: DashMap<Uuid, f32>,
/// Exponential moving average of squared embedding components (see
/// `update_fisher`); larger values shrink the weight updates made by
/// `record_feedback` for that dimension.
fisher_diagonal: Vec<f32>,
/// Number of `record_query` calls that carried at least one result.
query_count: u64,
}
/// Serializable record of a query embedding together with result embeddings.
///
/// NOTE(review): nothing visible in this file constructs or reads this type;
/// it only appears as the element type of `LearningEngine::query_patterns`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QueryPattern {
/// Embedding of the query.
query_embedding: Vec<f32>,
/// Result embeddings, each paired with an `f32` — presumably a score or
/// relevance value; TODO confirm against whatever populates this.
result_embeddings: Vec<(Vec<f32>, f32)>,
/// Time of the query; the unit/epoch is not established here — TODO confirm.
timestamp: u64,
}
impl LearningEngine {
/// Creates an engine for embeddings of width `dimensions` that uses
/// `learning_rate` as its update step.
///
/// Every relevance weight starts at 1.0, every Fisher estimate at 0.0, no
/// entry has a boost yet, and the query counter is zero.
pub fn new(dimensions: usize, learning_rate: f32) -> Self {
    // Built in the same order the fields are declared in the struct literal.
    let relevance_weights = vec![1.0; dimensions];
    // NOTE(review): the second and third arguments look like a hidden width and
    // a head/layer count — confirm against `GnnLayer::new`.
    let gnn_layer = GnnLayer::new(dimensions, dimensions * 2, 4);
    let replay_buffer = ReplayBuffer::new(10000);
    let query_patterns = VecDeque::with_capacity(1000);
    let entry_scores = DashMap::new();
    let fisher_diagonal = vec![0.0; dimensions];

    Self {
        dimensions,
        learning_rate,
        relevance_weights,
        gnn_layer,
        replay_buffer,
        query_patterns,
        entry_scores,
        fisher_diagonal,
        query_count: 0,
    }
}
/// Re-scores `candidates` against a GNN-transformed query and returns them
/// sorted by ascending distance.
///
/// Among the first 10 candidates, those with a stored vector in `vectors`
/// form the neighbourhood handed to the GNN layer; each neighbour's edge
/// weight is `1.0 - min(distance, 1.0)`. Every candidate that has a stored
/// vector then gets its distance replaced by the weighted distance between
/// the transformed query and that vector, divided by the entry's learned
/// boost (1.0 when none is recorded). Candidates without a stored vector
/// keep the distance they came in with.
pub fn rerank(
    &self,
    query_embedding: &[f32],
    mut candidates: Vec<(Uuid, f32)>,
    vectors: &DashMap<Uuid, Vec<f32>>,
) -> Vec<(Uuid, f32)> {
    // Build neighbours and edge weights in one pass so that position `k` in
    // one list always describes position `k` in the other, even when some of
    // the leading candidates have no stored vector.
    let (neighbors, edge_weights): (Vec<Vec<f32>>, Vec<f32>) = candidates
        .iter()
        .take(10)
        .filter_map(|(id, d)| vectors.get(id).map(|v| (v.clone(), 1.0 - d.min(1.0))))
        .unzip();

    let transformed_query = self
        .gnn_layer
        .forward(query_embedding, &neighbors, &edge_weights);

    for (id, distance) in &mut candidates {
        let Some(vector) = vectors.get(id) else {
            continue;
        };
        let weighted_distance = self.weighted_distance(&transformed_query, &vector);
        let entry_boost = self.entry_scores.get(id).map_or(1.0, |s| *s);
        *distance = weighted_distance / entry_boost;
    }

    // `total_cmp` is a total order even when a distance is NaN; a comparator
    // that is not a total order can make `sort_by` panic or leave the slice
    // in an unspecified order.
    candidates.sort_by(|a, b| a.1.total_cmp(&b.1));
    candidates
}
/// Returns `1 - cos_w(a, b)`, where the cosine similarity is computed with
/// each dimension scaled by its entry in `relevance_weights`.
///
/// Only the leading components shared by `a`, `b` and the configured
/// dimension count take part. When the weighted norm product is not
/// positive the result is 1.0.
fn weighted_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    // `relevance_weights` holds exactly `dimensions` entries, so zipping with
    // it visits the same indices as an explicit `0..min(len_a, len_b, dims)`.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .zip(&self.relevance_weights)
        .take(self.dimensions)
        .fold(
            (0.0f32, 0.0f32, 0.0f32),
            |(dot, na, nb), ((&x, &y), &w)| (dot + x * y * w, na + x * x * w, nb + y * y * w),
        );

    let norm = (norm_a_sq * norm_b_sq).sqrt();
    if norm > 0.0 {
        return 1.0 - (dot / norm);
    }
    1.0
}
/// Logs the results of one query into the replay buffer.
///
/// Does nothing when `results` is empty. Otherwise the query counter is
/// incremented, one experience per result is stored (its rank being the
/// result's position in `results`), and on every 10th counted query a
/// replay-learning pass is run.
pub fn record_query(&mut self, query_embedding: &[f32], results: &[SearchResult]) {
    if results.is_empty() {
        return;
    }

    self.query_count += 1;

    for (position, hit) in results.iter().enumerate() {
        let experience = Experience {
            query: query_embedding.to_vec(),
            result_id: hit.entry.id,
            rank: position as u32,
            score: hit.score,
        };
        self.replay_buffer.add(experience);
    }

    let replay_due = self.query_count.is_multiple_of(10);
    if replay_due {
        self.learn_from_replay();
    }
}
/// Adjusts the per-dimension relevance weights from feedback on one result.
///
/// Positive feedback raises each weight by `learning_rate * |component|`;
/// negative feedback lowers it by half that amount. Each step is scaled by
/// `1 / (1 + fisher_diagonal[i])`, so dimensions with a larger Fisher
/// estimate move less, and the result is clamped to `[0.1, 10.0]`. The
/// Fisher estimates are then refreshed from the same embedding.
///
/// Components beyond the configured dimension count are ignored. If any
/// component within that range is NaN or infinite, the whole feedback event
/// is discarded and no state changes.
pub fn record_feedback(&mut self, result_embedding: &[f32], positive: bool) {
    // `clamp` propagates NaN, and an infinite component drives the Fisher
    // estimate to infinity (after which `inf * 0` yields NaN), so a single
    // bad value would corrupt the learned state permanently.
    let has_non_finite = result_embedding
        .iter()
        .take(self.dimensions)
        .any(|v| !v.is_finite());
    if has_non_finite {
        return;
    }

    let adjustment = if positive {
        self.learning_rate
    } else {
        -self.learning_rate * 0.5
    };

    // Both vectors hold `dimensions` entries, so zipping stops at
    // `min(dimensions, result_embedding.len())`.
    for ((weight, &fisher), &val) in self
        .relevance_weights
        .iter_mut()
        .zip(&self.fisher_diagonal)
        .zip(result_embedding)
    {
        let delta = adjustment * val.abs();
        let ewc_factor = 1.0 / (1.0 + fisher);
        *weight = (*weight + delta * ewc_factor).clamp(0.1, 10.0);
    }

    self.update_fisher(result_embedding);
}
/// Folds `embedding` into the running Fisher estimates: each tracked
/// dimension becomes `0.99 * old + 0.01 * component²`.
///
/// Components past the configured dimension count are ignored, and
/// dimensions past the end of `embedding` are left untouched.
fn update_fisher(&mut self, embedding: &[f32]) {
    // `fisher_diagonal` has `dimensions` entries, so the zip is bounded by
    // both the dimension count and the embedding length.
    self.fisher_diagonal
        .iter_mut()
        .zip(embedding)
        .for_each(|(fisher, &component)| {
            *fisher = 0.99 * *fisher + 0.01 * component * component;
        });
}
fn learn_from_replay(&mut self) {
let samples = self.replay_buffer.sample(32);
for experience in samples {
let target_boost = 1.0 + (1.0 / (1.0 + experience.rank as f32));
self.entry_scores
.entry(experience.result_id)
.and_modify(|s| {
*s = f32::midpoint(*s, target_boost);
})
.or_insert(target_boost);
self.gnn_layer
.update(&experience.query, target_boost, self.learning_rate);
}
}
/// Returns how many calls to `record_query` carried at least one result;
/// calls with an empty result slice are not counted.
pub fn query_count(&self) -> u64 {
self.query_count
}
/// Returns a snapshot of the engine's learning state.
///
/// `avg_relevance_weight` is the arithmetic mean of the relevance weights
/// and `weight_variance` their population variance (sum of squared
/// deviations divided by the number of weights). For an engine with zero
/// dimensions both are reported as 0.0 rather than NaN.
pub fn stats(&self) -> LearningStats {
    // `relevance_weights` always has `dimensions` entries; using its length
    // keeps the divisor tied to the data actually summed.
    let weight_count = self.relevance_weights.len();

    let (avg_weight, weight_variance): (f32, f32) = if weight_count == 0 {
        // Avoid `0.0 / 0.0`, which would put NaN into the reported stats.
        (0.0, 0.0)
    } else {
        let n = weight_count as f32;
        let avg = self.relevance_weights.iter().sum::<f32>() / n;
        let variance = self
            .relevance_weights
            .iter()
            .map(|w| (w - avg).powi(2))
            .sum::<f32>()
            / n;
        (avg, variance)
    };

    LearningStats {
        query_count: self.query_count,
        replay_buffer_size: self.replay_buffer.len(),
        learned_entries: self.entry_scores.len(),
        avg_relevance_weight: avg_weight,
        weight_variance,
    }
}
}
/// One (query, result) observation stored in the replay buffer by
/// `LearningEngine::record_query`.
#[derive(Debug, Clone)]
struct Experience {
/// Copy of the query embedding the result was returned for.
query: Vec<f32>,
/// Id of the returned entry.
result_id: Uuid,
/// Zero-based position of the result in the recorded result slice.
rank: u32,
// Copied from the search result's `score` but not read anywhere in this
// file, hence the `dead_code` allowance.
#[allow(dead_code)]
score: f32,
}
/// Serializable snapshot produced by `LearningEngine::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
/// Number of recorded queries that had at least one result.
pub query_count: u64,
/// Current length reported by the replay buffer.
pub replay_buffer_size: usize,
/// Number of entries that have a learned score boost.
pub learned_entries: usize,
/// Arithmetic mean of the per-dimension relevance weights.
pub avg_relevance_weight: f32,
/// Population variance of the per-dimension relevance weights.
pub weight_variance: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeEntry;

    /// Builds `n` synthetic search results whose two numeric arguments vary
    /// linearly with the result index.
    fn fake_results(n: usize) -> Vec<SearchResult> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let entry = KnowledgeEntry::new(format!("t{i}"), "c");
            out.push(SearchResult::new(
                entry,
                0.9 - i as f32 * 0.05,
                0.1 * i as f32,
            ));
        }
        out
    }

    /// A fresh engine remembers its dimension count and has seen no queries.
    #[test]
    fn test_learning_engine_creation() {
        let fresh = LearningEngine::new(128, 0.01);
        assert_eq!(fresh.dimensions, 128);
        assert_eq!(fresh.query_count, 0);
    }

    /// Positive feedback moves the relevance weights away from their
    /// initial values.
    #[test]
    fn test_feedback_updates_weights() {
        let mut engine = LearningEngine::new(64, 0.1);
        let before = engine.relevance_weights.clone();
        let embedding = vec![0.5; 64];
        engine.record_feedback(&embedding, true);
        assert_ne!(engine.relevance_weights, before);
    }

    /// Negative feedback also changes the weights, and they stay inside the
    /// clamp range.
    #[test]
    fn negative_feedback_also_updates_weights() {
        let mut engine = LearningEngine::new(32, 0.2);
        let snapshot = engine.relevance_weights.clone();
        engine.record_feedback(&[0.4; 32], false);
        assert_ne!(engine.relevance_weights, snapshot);
        engine
            .relevance_weights
            .iter()
            .for_each(|w| assert!(*w >= 0.1 && *w <= 10.0));
    }

    /// Recording a query with no results leaves the counter untouched.
    #[test]
    fn record_query_empty_is_noop() {
        let mut engine = LearningEngine::new(16, 0.1);
        engine.record_query(&[0.0; 16], &[]);
        assert_eq!(engine.query_count(), 0);
    }

    /// Twelve recorded queries cross the every-10th-query threshold, so the
    /// replay pass has run and produced learned entries.
    #[test]
    fn record_query_increments_and_triggers_replay_learning() {
        let mut engine = LearningEngine::new(16, 0.1);
        let query = vec![0.3; 16];
        let results = fake_results(3);
        (0..12).for_each(|_| engine.record_query(&query, &results));

        assert_eq!(engine.query_count(), 12);
        let snapshot = engine.stats();
        assert_eq!(snapshot.query_count, 12);
        assert!(snapshot.replay_buffer_size > 0);
        assert!(snapshot.learned_entries > 0);
        assert!(snapshot.avg_relevance_weight > 0.0);
    }

    /// Reranking keeps every candidate and returns them in non-decreasing
    /// distance order (NaN distances are tolerated).
    ///
    /// NOTE(review): despite its name, this test only checks sortedness, not
    /// that the order actually changed.
    #[test]
    fn rerank_changes_candidate_order() {
        let engine = LearningEngine::new(16, 0.1);
        let vectors: DashMap<Uuid, Vec<f32>> = DashMap::new();
        let candidates: Vec<(Uuid, f32)> = (0..3)
            .map(|i: usize| {
                let id = Uuid::new_v4();
                let mut one_hot = vec![0.0; 16];
                one_hot[i % 16] = 1.0;
                vectors.insert(id, one_hot);
                (id, 0.5)
            })
            .collect();

        let query = vec![0.1; 16];
        let reranked = engine.rerank(&query, candidates.clone(), &vectors);

        assert_eq!(reranked.len(), candidates.len());
        reranked.windows(2).for_each(|pair| {
            assert!(pair[0].1 <= pair[1].1 || pair[0].1.is_nan() || pair[1].1.is_nan());
        });
    }

    /// Reranking an empty candidate list yields an empty list.
    #[test]
    fn rerank_empty_candidates_returns_empty() {
        let engine = LearningEngine::new(16, 0.1);
        let vectors: DashMap<Uuid, Vec<f32>> = DashMap::new();
        let reranked = engine.rerank(&[0.0; 16], Vec::new(), &vectors);
        assert!(reranked.is_empty());
    }
}