1pub mod batch;
45pub mod feedback;
46mod gnn;
47mod replay;
48
49pub use batch::{
50 BatchInput, BatchJob, BatchScheduler, EntryMetadata, Insight, InsightStore, JobRun, JobStatus,
51 JobType, KnowledgeClass, RelationshipType, Trend,
52};
53pub use feedback::{
54 FeedbackCollector, FeedbackConfig, FeedbackProcessor, FeedbackSignal, ProcessedFeedback,
55 QueryId, SessionId, SignalType,
56};
57pub use gnn::GnnLayer;
58pub use replay::ReplayBuffer;
59
60use crate::core::SearchResult;
61use dashmap::DashMap;
62use serde::{Deserialize, Serialize};
63use std::collections::VecDeque;
64use uuid::Uuid;
65
66pub struct LearningEngine {
73 dimensions: usize,
75
76 learning_rate: f32,
78
79 relevance_weights: Vec<f32>,
81
82 gnn_layer: GnnLayer,
84
85 replay_buffer: ReplayBuffer<Experience>,
87
88 #[allow(dead_code)]
90 query_patterns: VecDeque<QueryPattern>,
91
92 entry_scores: DashMap<Uuid, f32>,
94
95 fisher_diagonal: Vec<f32>,
97
98 query_count: u64,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104struct QueryPattern {
105 query_embedding: Vec<f32>,
107
108 result_embeddings: Vec<(Vec<f32>, f32)>,
110
111 timestamp: u64,
113}
114
115impl LearningEngine {
116 pub fn new(dimensions: usize, learning_rate: f32) -> Self {
118 Self {
119 dimensions,
120 learning_rate,
121 relevance_weights: vec![1.0; dimensions],
122 gnn_layer: GnnLayer::new(dimensions, dimensions * 2, 4),
123 replay_buffer: ReplayBuffer::new(10000),
124 query_patterns: VecDeque::with_capacity(1000),
125 entry_scores: DashMap::new(),
126 fisher_diagonal: vec![0.0; dimensions],
127 query_count: 0,
128 }
129 }
130
131 pub fn rerank(
133 &self,
134 query_embedding: &[f32],
135 mut candidates: Vec<(Uuid, f32)>,
136 vectors: &DashMap<Uuid, Vec<f32>>,
137 ) -> Vec<(Uuid, f32)> {
138 let neighbors: Vec<Vec<f32>> = candidates
140 .iter()
141 .take(10)
142 .filter_map(|(id, _)| vectors.get(id).map(|v| v.clone()))
143 .collect();
144
145 let edge_weights: Vec<f32> = candidates
146 .iter()
147 .take(10)
148 .map(|(_, d)| 1.0 - d.min(1.0))
149 .collect();
150
151 let transformed_query = self
152 .gnn_layer
153 .forward(query_embedding, &neighbors, &edge_weights);
154
155 for (id, distance) in &mut candidates {
157 if let Some(vector) = vectors.get(id) {
158 let weighted_distance = self.weighted_distance(&transformed_query, &vector);
160
161 let entry_boost = self.entry_scores.get(id).map_or(1.0, |s| *s);
163
164 *distance = weighted_distance / entry_boost;
165 }
166 }
167
168 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
170
171 candidates
172 }
173
174 fn weighted_distance(&self, a: &[f32], b: &[f32]) -> f32 {
176 let mut weighted_dot = 0.0f32;
177 let mut weighted_norm_a = 0.0f32;
178 let mut weighted_norm_b = 0.0f32;
179
180 for i in 0..a.len().min(b.len()).min(self.dimensions) {
181 let w = self.relevance_weights[i];
182 weighted_dot += a[i] * b[i] * w;
183 weighted_norm_a += a[i] * a[i] * w;
184 weighted_norm_b += b[i] * b[i] * w;
185 }
186
187 let norm = (weighted_norm_a * weighted_norm_b).sqrt();
188 if norm > 0.0 {
189 1.0 - (weighted_dot / norm)
190 } else {
191 1.0
192 }
193 }
194
195 pub fn record_query(&mut self, query_embedding: &[f32], results: &[SearchResult]) {
197 if results.is_empty() {
198 return;
199 }
200
201 self.query_count += 1;
202
203 for (rank, result) in results.iter().enumerate() {
205 self.replay_buffer.add(Experience {
206 query: query_embedding.to_vec(),
207 result_id: result.entry.id,
208 rank: rank as u32,
209 score: result.score,
210 });
211 }
212
213 if self.query_count.is_multiple_of(10) {
215 self.learn_from_replay();
216 }
217 }
218
219 pub fn record_feedback(&mut self, result_embedding: &[f32], positive: bool) {
221 let adjustment = if positive {
222 self.learning_rate
223 } else {
224 -self.learning_rate * 0.5
225 };
226
227 for (i, &val) in result_embedding.iter().enumerate() {
229 if i < self.dimensions {
230 let delta = adjustment * val.abs();
232
233 let ewc_factor = 1.0 / (1.0 + self.fisher_diagonal[i]);
235
236 self.relevance_weights[i] =
237 (self.relevance_weights[i] + delta * ewc_factor).clamp(0.1, 10.0);
238 }
239 }
240
241 self.update_fisher(result_embedding);
243 }
244
245 fn update_fisher(&mut self, embedding: &[f32]) {
247 for (i, &val) in embedding.iter().enumerate() {
248 if i < self.dimensions {
249 self.fisher_diagonal[i] = 0.99 * self.fisher_diagonal[i] + 0.01 * val * val;
251 }
252 }
253 }
254
255 fn learn_from_replay(&mut self) {
257 let samples = self.replay_buffer.sample(32);
258
259 for experience in samples {
260 let target_boost = 1.0 + (1.0 / (1.0 + experience.rank as f32));
262
263 self.entry_scores
265 .entry(experience.result_id)
266 .and_modify(|s| {
267 *s = f32::midpoint(*s, target_boost);
268 })
269 .or_insert(target_boost);
270
271 self.gnn_layer
273 .update(&experience.query, target_boost, self.learning_rate);
274 }
275 }
276
277 pub fn query_count(&self) -> u64 {
279 self.query_count
280 }
281
282 pub fn stats(&self) -> LearningStats {
284 let avg_weight: f32 = self.relevance_weights.iter().sum::<f32>() / self.dimensions as f32;
285 let weight_variance: f32 = self
286 .relevance_weights
287 .iter()
288 .map(|w| (w - avg_weight).powi(2))
289 .sum::<f32>()
290 / self.dimensions as f32;
291
292 LearningStats {
293 query_count: self.query_count,
294 replay_buffer_size: self.replay_buffer.len(),
295 learned_entries: self.entry_scores.len(),
296 avg_relevance_weight: avg_weight,
297 weight_variance,
298 }
299 }
300}
301
302#[derive(Debug, Clone)]
304struct Experience {
305 query: Vec<f32>,
306 result_id: Uuid,
307 rank: u32,
308 #[allow(dead_code)]
310 score: f32,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct LearningStats {
316 pub query_count: u64,
318 pub replay_buffer_size: usize,
320 pub learned_entries: usize,
322 pub avg_relevance_weight: f32,
324 pub weight_variance: f32,
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeEntry;

    /// Builds `n` search results whose second constructor argument decreases
    /// and whose third increases with the index.
    fn fake_results(n: usize) -> Vec<SearchResult> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let entry = KnowledgeEntry::new(format!("t{i}"), "c");
            out.push(SearchResult::new(entry, 0.9 - i as f32 * 0.05, 0.1 * i as f32));
        }
        out
    }

    #[test]
    fn test_learning_engine_creation() {
        let fresh = LearningEngine::new(128, 0.01);
        assert_eq!(fresh.query_count, 0);
        assert_eq!(fresh.dimensions, 128);
    }

    #[test]
    fn test_feedback_updates_weights() {
        let mut engine = LearningEngine::new(64, 0.1);
        let snapshot = engine.relevance_weights.clone();

        engine.record_feedback(&[0.5; 64], true);

        assert_ne!(engine.relevance_weights, snapshot);
    }

    #[test]
    fn negative_feedback_also_updates_weights() {
        let mut engine = LearningEngine::new(32, 0.2);
        let snapshot = engine.relevance_weights.clone();

        engine.record_feedback(&[0.4; 32], false);

        assert_ne!(engine.relevance_weights, snapshot);
        // Every weight must stay inside the clamp range.
        assert!(engine
            .relevance_weights
            .iter()
            .all(|w| (0.1..=10.0).contains(w)));
    }

    #[test]
    fn record_query_empty_is_noop() {
        let mut engine = LearningEngine::new(16, 0.1);
        engine.record_query(&[0.0; 16], &[]);
        assert_eq!(engine.query_count(), 0);
    }

    #[test]
    fn record_query_increments_and_triggers_replay_learning() {
        let mut engine = LearningEngine::new(16, 0.1);
        let query = vec![0.3; 16];
        let results = fake_results(3);

        // Twelve queries cross the every-10th replay threshold once.
        (0..12).for_each(|_| engine.record_query(&query, &results));

        assert_eq!(engine.query_count(), 12);
        let stats = engine.stats();
        assert_eq!(stats.query_count, 12);
        assert!(stats.replay_buffer_size > 0);
        assert!(stats.learned_entries > 0);
        assert!(stats.avg_relevance_weight > 0.0);
    }

    #[test]
    fn rerank_changes_candidate_order() {
        let engine = LearningEngine::new(16, 0.1);
        let vectors: DashMap<Uuid, Vec<f32>> = DashMap::new();
        let candidates: Vec<(Uuid, f32)> = (0..3)
            .map(|i| {
                let id = Uuid::new_v4();
                let mut one_hot = vec![0.0; 16];
                one_hot[i % 16] = 1.0;
                vectors.insert(id, one_hot);
                (id, 0.5)
            })
            .collect();

        let reranked = engine.rerank(&[0.1; 16], candidates.clone(), &vectors);

        assert_eq!(reranked.len(), candidates.len());
        assert!(reranked
            .windows(2)
            .all(|pair| pair[0].1 <= pair[1].1 || pair[0].1.is_nan() || pair[1].1.is_nan()));
    }

    #[test]
    fn rerank_empty_candidates_returns_empty() {
        let engine = LearningEngine::new(16, 0.1);
        let vectors: DashMap<Uuid, Vec<f32>> = DashMap::new();
        assert!(engine.rerank(&[0.0; 16], Vec::new(), &vectors).is_empty());
    }
}
425}